
#ifndef lenet_comp_h
#define lenet_comp_h

#include "public.h"



// Rectified linear unit: positive inputs pass through unchanged, everything
// else is clamped to zero.
float relu(float x)
{
    return (x > 0.f) ? x : 0.f;
}

// Derivative of relu, evaluated on the activation value y (the relu output)
// rather than on the pre-activation: 1 where the unit fired (y > 0), else 0.
float relugrad(float y)
{
    if (y > 0.f)
        return 1.f;
    return 0.f;
}


// Convolution layer forward pass, one thread per OUTPUT element.
// output[x][y][z] = relu(bias[x] + sum over input channel i and kernel tap
// (w0, w1) of input[i][y + w0][z + w1] * weight[i][x][w0][w1]), i.e. a
// stride-1, unpadded cross-correlation (the kernel is not flipped).
//   id      - flat output index, decoded as x * oH * oW + y * oW + z;
//             threads with id >= oN * oH * oW do nothing
//   oW, oH  - output width / height
//   wW, wH  - kernel width / height
//   iN, oN  - input / output channel counts
// Size arguments are parenthesized so expression arguments expand safely.
// NOTE(review): FOREACH comes from public.h and is assumed to loop its index
// over [0, n) -- confirm there.
#define CONVOLUTION_FORWARD(input, output, weight, bias, id, oW, oH, wW, wH, iN, oN)\
{                                                                                   \
    if((id) < (oN) * (oW) * (oH))                                                   \
    {                                                                               \
        uint y = (id) / (oW);                                                       \
        uint z = (id) - y * (oW);                                                   \
        uint x = y / (oH);                                                          \
        y -= x * (oH);                                                              \
        float res = 0.f;                                                            \
        FOREACH(i, iN)                                                              \
            FOREACH(w0, wH)                                                         \
                FOREACH(w1, wW)                                                     \
                    res += (input)[i][y + w0][z + w1] * (weight)[i][x][w0][w1];     \
        (output)[x][y][z] = relu(res + (bias)[x]);                                  \
    }                                                                               \
}

// Max-pooling forward pass, one thread per OUTPUT element.
// output[x][y][z] is the maximum of the kH x kW window of input channel x
// whose top-left corner is (y * kH, z * kW); windows do not overlap (the
// stride equals the window size).
//   id      - flat output index, decoded as x * oH * oW + y * oW + z;
//             threads with id >= N * oH * oW do nothing
//   oW, oH  - pooled (output) width / height
//   kW, kH  - pooling window width / height
//   N       - channel count
// The running maximum is seeded with the window's first element instead of
// -1.f / 0.f: an infinity from division by zero is not guaranteed on every
// GLSL / Vulkan target, and a finite seed removes that dependence.
// NOTE(review): FOREACH comes from public.h and is assumed to loop its index
// over [0, n) -- confirm there.
#define DOWNSAMP_MAX_FORWARD(input, output, id, oW, oH, kW, kH, N)        \
{                                                                         \
    if((id) < (N) * (oW) * (oH))                                          \
    {                                                                     \
        uint y = (id) / (oW);                                             \
        uint z = (id) - y * (oW);                                         \
        uint x = y / (oH);                                                \
        y -= x * (oH);                                                    \
        float res = (input)[x][y * (kH)][z * (kW)];                       \
        FOREACH(w0, kH)                                                   \
            FOREACH(w1, kW)                                               \
                res = max(res, (input)[x][y * (kH) + w0][z * (kW) + w1]); \
        (output)[x][y][z] = res;                                          \
    }                                                                     \
}

// Fully-connected layer forward pass, one thread per output neuron id.
// The iN x iH x iW input volume is walked in row-major order, so element
// input[i][j][k] pairs with weight row x = (i * iH + j) * iW + k; the result
// is output[id] = relu(bias[id] + sum_x input_flat[x] * weight[x][id]).
//   id      - output neuron index; threads with id >= oN do nothing
//   iW, iH  - input width / height
//   iN      - input channel count
//   oN      - number of output neurons
// Note: spelled without the underscore used by DOT_PRODUCT_BACKWARD and
// DOT_PRODUCT_DELTA; the name is kept for existing callers.
#define DOTPRODUCT_FORWARD(input, output, weight, bias, id, iW, iH, iN, oN) \
{                                                                           \
    if((id) < (oN))                                                         \
    {                                                                       \
        float res = 0.f;                                                    \
        for(int i = 0, x = 0; i < (iN); ++i)                                \
            for(int j = 0; j < (iH); ++j)                                   \
                for(int k = 0; k < (iW); ++k, ++x)                          \
                    res += (input)[i][j][k] * (weight)[x][id];              \
        (output)[id] = relu(res + (bias)[id]);                              \
    }                                                                       \
}

// Convolution layer error back-propagation, one thread per INPUT element
// (each thread writes only its own ierr element, so no write conflicts).
// For input element (x, y, z) it sums oerr[o][y - w0][z - w1] *
// weight[x][o][w0][w1] over every output channel o and kernel tap (w0, w1)
// whose output position exists, then scales by relugrad of the stored
// activation input[x][y][z].
//   id      - flat input index, decoded as x * iH * iW + y * iW + z;
//             threads with id >= iN * iH * iW do nothing
//   iW, iH  - input width / height (the bounds test implies an output size
//             of iH - wH + 1 by iW - wW + 1)
//   wW, wH  - kernel width / height
//   iN, oN  - input / output channel counts
// NOTE(review): FOREACH comes from public.h and is assumed to loop its index
// over [0, n) -- confirm there.
#define CONVOLUTION_BACKWARD(input, ierr, oerr, weight, id, iW, iH, wW, wH, iN, oN)     \
{                                                                                       \
    if ((id) < (iN) * (iW) * (iH))                                                      \
    {                                                                                   \
        uint y = (id) / (iW);                                                           \
        uint z = (id) - y * (iW);                                                       \
        uint x = y / (iH);                                                              \
        y -= x * (iH);                                                                  \
        float res = 0.f;                                                                \
        FOREACH(o, oN)                                                                  \
            FOREACH(w0, wH)                                                             \
                FOREACH(w1, wW)                                                         \
                    /* w0 <= y / w1 <= z are tested first so the unsigned  */           \
                    /* differences are only evaluated when non-negative.  */            \
                    if (w0 <= y && y - w0 + (wH) <= (iH) &&                             \
                        w1 <= z && z - w1 + (wW) <= (iW))                               \
                        res += (oerr)[o][y - w0][z - w1] * (weight)[x][o][w0][w1];      \
        (ierr)[x][y][z] = res * relugrad((input)[x][y][z]);                             \
    }                                                                                   \
}

// Max-pooling error back-propagation, one thread per OUTPUT element.
// The thread zeroes the wH x wW window of ierr that fed output (x, y, z),
// then writes oerr[x][y][z] at the position of that window's maximum in
// input. The comparison is strict, so on ties the first maximum in row-major
// window order receives the error.
//   id      - flat output index, decoded as x * oH * oW + y * oW + z;
//             threads with id >= oN * oH * oW do nothing
//   oW, oH  - pooled (output) width / height
//   wW, wH  - pooling window width / height (also the stride)
//   iN      - unused by this macro; oN is the channel count
// NOTE(review): FOREACH comes from public.h and is assumed to loop its index
// over [0, n) -- confirm there.
#define DOWNSAMP_MAX_BACKWARD(input, ierr, oerr, id, oW, oH, wW, wH, iN, oN)\
{                                                                           \
    if ((id) < (oN) * (oW) * (oH))                                          \
    {                                                                       \
        uint y = (id) / (oW);                                               \
        uint z = (id) - y * (oW);                                           \
        uint x = y / (oH);                                                  \
        y -= x * (oH);                                                      \
        const float res = (oerr)[x][y][z];                                  \
        /* Move (y, z) to the top-left corner of the input window. */       \
        y *= (wH);                                                          \
        z *= (wW);                                                          \
        uint y0 = 0, z0 = 0;                                                \
        FOREACH(w0, wH)                                                     \
            FOREACH(w1, wW)                                                 \
            {                                                               \
                (ierr)[x][y + w0][z + w1] = 0.f;                            \
                if((input)[x][y + y0][z + z0] < (input)[x][y + w0][z + w1]) \
                {                                                           \
                    y0 = w0;                                                \
                    z0 = w1;                                                \
                }                                                           \
            }                                                               \
        (ierr)[x][y + y0][z + z0] = res;                                    \
    }                                                                       \
}

// Fully-connected layer error back-propagation, one thread per INPUT element.
// ierr[x][y][z] = relugrad(input[x][y][z]) * sum_o oerr[o] * weight[id][o],
// where the flat input index id doubles as the weight row.
//   id      - flat input index, decoded as x * iH * iW + y * iW + z;
//             threads with id >= iN * iH * iW do nothing
//   iW, iH  - input width / height
//   iN      - input channel count
//   oN      - number of output neurons
// NOTE(review): FOREACH comes from public.h and is assumed to loop its index
// over [0, n) -- confirm there.
#define DOT_PRODUCT_BACKWARD(input, ierr, oerr, weight, id, iW, iH, iN, oN)\
{                                                                          \
    if ((id) < (iN) * (iW) * (iH))                                         \
    {                                                                      \
        uint y = (id) / (iW);                                              \
        uint z = (id) - y * (iW);                                          \
        uint x = y / (iH);                                                 \
        y -= x * (iH);                                                     \
        float res = 0.f;                                                   \
        /* id equals (x * iH + y) * iW + z, the same row-major flattening */ \
        /* DOTPRODUCT_FORWARD uses to pick weight rows.                  */  \
        FOREACH(o, oN)                                                     \
            res += (oerr)[o] * (weight)[id][o];                            \
        (ierr)[x][y][z] = res * relugrad((input)[x][y][z]);                \
    }                                                                      \
}


// Convolution layer weight/bias gradient, one thread per weight
// weight[x][y][z][w] (x = input channel, y = output channel, (z, w) = kernel
// tap). It correlates input channel x with the error of output channel y over
// the whole oH x oW output and stores the sum. The single thread with
// x == z == w == 0 also stores the summed error of channel y as bias[y]; the
// other threads compute that sum too but discard it.
// Results are assigned (=), not accumulated, so previous contents of
// weight / bias are overwritten.
//   id      - flat weight index, decoded as ((x * oN + y) * wH + z) * wW + w;
//             threads with id >= iN * oN * wH * wW do nothing
//   oW, oH  - output width / height of the forward convolution
//   wW, wH  - kernel width / height
//   iN, oN  - input / output channel counts
// NOTE(review): FOREACH comes from public.h and is assumed to loop its index
// over [0, n) -- confirm there.
#define CONVOLUTION_DELTA(input, oerr, weight, bias, id, oW, oH, wW, wH, iN, oN)\
{                                                                               \
    if((id) < (iN) * (oN) * (wH) * (wW))                                        \
    {                                                                           \
        uint z = (id) / (wW);                                                   \
        uint w = (id) - z * (wW);                                               \
        uint y = z / (wH);                                                      \
        z -= y * (wH);                                                          \
        uint x = y / (oN);                                                      \
        y -= x * (oN);                                                          \
        float deltaW = 0.f, deltaB = 0.f;                                       \
        FOREACH(o0, oH)                                                         \
            FOREACH(o1, oW)                                                     \
            {                                                                   \
                float err = (oerr)[y][o0][o1];                                  \
                deltaW += (input)[x][o0 + z][o1 + w] * err;                     \
                deltaB += err;                                                  \
            }                                                                   \
        (weight)[x][y][z][w] = deltaW;                                          \
        if(x == 0 && z == 0 && w == 0)                                          \
            (bias)[y] = deltaB;                                                 \
    }                                                                           \
}

// Fully-connected layer weight/bias gradient, one thread per weight
// weight[x][y], where x is the flat input index (x = (i * iH + j) * iW + k,
// the same flattening DOTPRODUCT_FORWARD uses) and y the output neuron.
// Stores input[i][j][k] * oerr[y]; the threads with x == 0 also store
// bias[y] = oerr[y]. Results are assigned (=), not accumulated.
//   id      - flat weight index, decoded as x * oN + y;
//             threads with id >= iN * iH * iW * oN do nothing
//   iW, iH  - input width / height
//   iN      - input channel count
//   oN      - number of output neurons
// The guard covers all iN * iH * iW weight rows. The previous bound
// (iN * oN) left rows x >= iN untouched whenever iH * iW > 1; results are
// unchanged for 1 x 1 inputs.
#define DOT_PRODUCT_DELTA(input, oerr, weight, bias, id, iW, iH, iN, oN)\
{                                                                       \
    if((id) < (iN) * (iH) * (iW) * (oN))                                \
    {                                                                   \
        uint x = (id) / (oN);                                           \
        uint y = (id) - x * (oN);                                       \
        uint k = x;                                                     \
        uint j = k / (iW);                                              \
        k -= j * (iW);                                                  \
        uint i = j / (iH);                                              \
        j -= i * (iH);                                                  \
        (weight)[x][y] = (input)[i][j][k] * (oerr)[y];                  \
        if(x == 0)                                                      \
            (bias)[y] = (oerr)[y];                                      \
    }                                                                   \
}


// Shader storage buffers declared for the compute passes that include this
// header. The LeNet5 and Feature types come from public.h.
// NOTE(review): no layout qualifier (std430 / std140) is given, so member
// packing is the compiler's default for buffer blocks -- confirm it matches
// the host-side definitions of LeNet5 and Feature.

// binding 0: a single LeNet5 instance, presumably the network's weights and
// biases -- confirm against public.h.
layout(binding = 0) buffer b0
{
    LeNet5 lenet;
};

// binding 1: runtime-sized array of Feature; presumably one set of layer
// activations per sample -- confirm against the host code.
layout(binding = 1) buffer b1
{
    Feature feature[];
};

// binding 2: runtime-sized array of Feature, presumably the per-sample error
// terms that mirror `feature` -- confirm against the host code.
layout(binding = 2) buffer b2
{
    Feature error[];
};

// binding 3: runtime-sized array of LeNet5, presumably per-sample gradients
// (the *_DELTA macros assign rather than accumulate, which suggests a
// reduction happens elsewhere) -- TODO confirm.
layout(binding = 3) buffer b3
{
    LeNet5 delta[];
};

#endif /* lenet_comp_h */
